{% extends 'base.exported.class' %}

{% block content %}
var fs = require('fs');


var {{ class_name }} = function(jsonFile) {
    this.data = function(data) {
        data.hidden_activation = data.hidden_activation.toUpperCase();
        data.output_activation = data.output_activation.toUpperCase();
        data.network = new Array(data.layers.length + 1);
        for (var i = 0, l = data.layers.length; i < l; i++) {
            data.network[i + 1] = new Array(data.layers[i]).fill(0.);
        }
        return data;
    }(JSON.parse(fs.readFileSync(jsonFile)));

    this._resetNetwork = function() {
        for (var i = 1, l = this.data.network.length - 1; i < l; i++) {
            for (var j = 0; j < this.data.network[i].length; j++) {
                this.data.network[i][j] = 0;
            }
        }
    };

    var _findMax = function(nums) {
        var idx = 0;
        for (var i = 0, l = nums.length; i < l; i++) {
            idx = nums[i] > nums[idx] ? i : idx;
        }
        return idx;
    };

    var _compute = function(activation, v) {
        var i, l = v.length;
        switch (activation) {
            case 'LOGISTIC':
                for (i = 0; i < l; i++) {
                    v[i] = 1. / (1. + Math.exp(-v[i]));
                }
                break;
            case 'RELU':
                for (i = 0; i < l; i++) {
                    v[i] = Math.max(0, v[i]);
                }
                break;
            case 'TANH':
                for (i = 0; i < l; i++) {
                    v[i] = Math.tanh(v[i]);
                }
                break;
            case 'SOFTMAX':
                var max = Number.NEGATIVE_INFINITY;
                for (i = 0; i < l; i++) {
                    if (v[i] > max) {
                        max = v[i];
                    }
                }
                for (i = 0; i < l; i++) {
                    v[i] = Math.exp(v[i] - max);
                }
                var sum = 0.0;
                for (i = 0; i < l; i++) {
                    sum += v[i];
                }
                for (i = 0; i < l; i++) {
                    v[i] /= sum;
                }
                break;
        }
        return v;
    };

    this.feedForward = function(neurons) {
        this.data.network[0] = neurons;
        for (var i = 0; i < this.data.network.length - 1; i++) {
            for (var j = 0; j < this.data.network[i + 1].length; j++) {
                this.data.network[i + 1][j] = this.data.bias[i][j];
                for (var l = 0; l < this.data.network[i].length; l++) {
                    this.data.network[i + 1][j] += this.data.network[i][l] * this.data.weights[i][l][j];
                }
            }
            if ((i + 1) < (this.data.network.length - 1)) {
                this.data.network[i + 1] = _compute(this.data.hidden_activation, this.data.network[i + 1]);
            }
        }
        this.data.network[this.data.network.length - 1] = _compute(this.data.output_activation, this.data.network[this.data.network.length - 1]);
        this._resetNetwork();
        return this.data.network[this.data.network.length - 1];
    };

    this.predictProba = function(neurons) {
        var lastLayer = this.feedForward(neurons);
        if (lastLayer.length === 1) {
            return [lastLayer[0], 1 - lastLayer[0]];
        } else {
            return lastLayer;
        }
    };

    this.predict = function(neurons) {
        var lastLayer = this.feedForward(neurons);
        if (lastLayer.length === 1) {
            if (lastLayer[0] > .5) {
                return 1;
            }
            return 0;
        } else {
            return _findMax(lastLayer);
        }
    };
};

var main = function () {
    // Features:
    var features = process.argv.slice(3);
    for (var i = 0; i < features.length; i++) {
        features[i] = parseFloat(features[i]);
    }

    // Model data:
    var json = process.argv[2];

    // Estimator:
    var clf = new {{ class_name }}(json);

    {% if is_test or to_json %}
    // Get JSON:
    console.log(JSON.stringify({
        "predict": clf.predict(features),
        "predict_proba": clf.predictProba(features)
    }));
    {% else %}
    // Get class prediction:
    var prediction = clf.predict(features);
    console.log("Predicted class: #" + prediction);

    // Get class probabilities:
    var probabilities = clf.predictProba(features);
    for (var i = 0; i < probabilities.length; i++) {
        console.log("Probability of class #" + i + " : " + probabilities[i]);
    }
    {% endif %}
}

// Run the command-line driver only when this file is executed directly,
// not when it is loaded through require() by another module.
if (module === require.main) {
    main();
}
{% endblock %}